{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 推荐系统组队学习 第一次打卡：赛题理解+baseline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 赛题要求：\n",
    "根据用户历史浏览点击新闻文章的数据信息预测用户未来的点击行为， 即用户的最后一次点击的新闻文章\n",
    "\n",
    "### 数据特征\n",
    "该数据来自某新闻APP平台的用户交互数据，包括30万用户，近300万次点击，共36万多篇不同的新闻文章，同时每篇新闻文章有对应的embedding向量表示。为了保证比赛的公平性，从中抽取20万用户的点击日志数据作为训练集，5万用户的点击日志数据作为测试集A，5万用户的点击日志数据作为测试集B。\n",
    "\n",
    "### 评价方式:\n",
    "针对每个用户，都会将他最有可能点击的五篇文章列出来；而真实结果只有一个，对于预测出的五个结果，会有两类可能：\n",
    "\n",
    "(1)五个全部没有预测对,则得到零分\n",
    "\n",
    "(2)第k个预测对了，则得到1/k分\n",
    "\n",
    "因此希望预测出的结果中，命中的结果尽量靠前，此时得分比较高。\n",
    "\n",
    "得分的数学公式：\n",
    "\n",
    "$$score(user) = \\sum_{k=1}^5 \\frac{s(user, k)}{k}$$\n",
    "\n",
    "### 赛题理解\n",
    "\n",
    "目标：根据用户历史浏览点击新闻的数据信息预测用户最后一次点击的新闻文章。\n",
    "\n",
    "数据特征：基于了真实的业务场景的用户的点击日志\n",
    "\n",
    "思路：把该预测问题转成一个监督学习的问题(特征+标签)，然后我们才能进行ML，DL等建模预测\n",
    "\n",
    "36万篇文章中预测某一篇的话我们首先可能会想到这可能是一个多分类的问题(36万类里面选1)， 但是如此庞大的分类问题， 我们做起来可能比较困难， 那么能不能转化一下？ 既然是要预测最后一次点击的文章， 那么如果我们能预测出某个用户最后一次对于某一篇文章会进行点击的概率， 是不是就间接性的解决了这个问题呢？概率最大的那篇文章不就是用户最后一次可能点击的新闻文章吗？ 这样就把原问题变成了一个点击率预测的问题(用户, 文章) --> 点击的概率(软分类)， 而这个问题， 就是我们所熟悉的监督学习领域分类问题了， 这样我们后面建模的时候， 对于模型的选择就基本上有大致方向了，比如最简单的逻辑回归模型。\n",
    "\n",
    "问题：如何转成监督学习问题？ 训练集和测试集怎么制作？ 我们又能利用哪些特征？ 我们又可以尝试哪些模型？ 面对36万篇文章， 20多万用户的推荐， 我们又有哪些策略来缩减问题的规模？如何进行最后的预测？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## baseline分析\n",
    "\n",
    "baseline使用了协同过滤，根据用户之前的喜好以及其他兴趣相近的用户的选择来给用户推荐物品\n",
    "\n",
    "使用了基于物品的协同过滤算法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 下载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-11-24 21:05:55--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/articles.csv\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 10365815 (9.9M) [text/csv]\n",
      "Saving to: ‘articles.csv’\n",
      "\n",
      "100%[======================================>] 10,365,815  --.-K/s   in 0.1s    \n",
      "\n",
      "2020-11-24 21:05:55 (66.2 MB/s) - ‘articles.csv’ saved [10365815/10365815]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/articles.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-11-24 21:06:16--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/articles_emb.csv\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1020419923 (973M) [text/csv]\n",
      "Saving to: ‘articles_emb.csv’\n",
      "\n",
      "100%[====================================>] 1,020,419,923 11.9MB/s   in 79s    \n",
      "\n",
      "2020-11-24 21:07:35 (12.4 MB/s) - ‘articles_emb.csv’ saved [1020419923/1020419923]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/articles_emb.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-11-24 21:07:40--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/testA_click_log.csv\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 21459692 (20M) [text/csv]\n",
      "Saving to: ‘testA_click_log.csv’\n",
      "\n",
      "100%[======================================>] 21,459,692  15.6MB/s   in 1.3s   \n",
      "\n",
      "2020-11-24 21:07:41 (15.6 MB/s) - ‘testA_click_log.csv’ saved [21459692/21459692]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/testA_click_log.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-11-24 21:07:57--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/train_click_log.csv\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 45609590 (43M) [text/csv]\n",
      "Saving to: ‘train_click_log.csv’\n",
      "\n",
      "100%[======================================>] 45,609,590  16.3MB/s   in 2.7s   \n",
      "\n",
      "2020-11-24 21:08:00 (16.3 MB/s) - ‘train_click_log.csv’ saved [45609590/45609590]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531842/train_click_log.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time, math, os\n",
    "from tqdm import tqdm #进度条\n",
    "import gc\n",
    "import pickle\n",
    "import random\n",
    "from datetime import datetime\n",
    "from operator import itemgetter\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import warnings\n",
    "import collections\n",
    "from collections import defaultdict\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mv train_click_log.csv articles.csv articles_emb.csv testA_click_log.csv data_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = './data_raw/'\n",
    "save_path = './tmp_results/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 下面使用一个函数，通过降低数据不需要的精度，来节省内存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 节约内存的一个标配函数\n",
    "def reduce_mem(df):\n",
    "    starttime = time.time()\n",
    "    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']\n",
    "    start_mem = df.memory_usage().sum() / 1024**2\n",
    "    for col in df.columns:\n",
    "        col_type = df[col].dtypes\n",
    "        if col_type in numerics:\n",
    "            c_min = df[col].min()\n",
    "            c_max = df[col].max()\n",
    "            if pd.isnull(c_min) or pd.isnull(c_max):\n",
    "                continue\n",
    "            if str(col_type)[:3] == 'int':\n",
    "                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:\n",
    "                    df[col] = df[col].astype(np.int8)\n",
    "                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:\n",
    "                    df[col] = df[col].astype(np.int16)\n",
    "                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:\n",
    "                    df[col] = df[col].astype(np.int32)\n",
    "                elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:\n",
    "                    df[col] = df[col].astype(np.int64)\n",
    "            else:\n",
    "                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:\n",
    "                    df[col] = df[col].astype(np.float16)\n",
    "                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:\n",
    "                    df[col] = df[col].astype(np.float32)\n",
    "                else:\n",
    "                    df[col] = df[col].astype(np.float64)\n",
    "    end_mem = df.memory_usage().sum() / 1024**2\n",
    "    print('-- Mem. usage decreased to {:5.2f} Mb ({:.1f}% reduction),time spend:{:2.2f} min'.format(end_mem,\n",
    "                                                                                                           100*(start_mem-end_mem)/start_mem,\n",
    "                                                                                                           (time.time()-starttime)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取一部分或者全部的数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# debug模式：从训练集中划出一部分数据来调试代码\n",
    "def get_all_click_sample(data_path, sample_nums=10000):\n",
    "    \"\"\"\n",
    "        训练集中采样一部分数据调试\n",
    "        data_path: 原数据的存储路径\n",
    "        sample_nums: 采样数目（这里由于机器的内存限制，可以采样用户做）\n",
    "    \"\"\"\n",
    "    all_click = pd.read_csv(data_path + 'train_click_log.csv')\n",
    "    all_user_ids = all_click.user_id.unique()\n",
    "\n",
    "    sample_user_ids = np.random.choice(all_user_ids, size=sample_nums, replace=False) \n",
    "    all_click = all_click[all_click['user_id'].isin(sample_user_ids)]\n",
    "    \n",
    "    all_click = all_click.drop_duplicates((['user_id', 'click_article_id', 'click_timestamp']))\n",
    "    return all_click\n",
    "\n",
    "# 读取点击数据，这里分成线上和线下，如果是为了获取线上提交结果应该讲测试集中的点击数据合并到总的数据中\n",
    "# 如果是为了线下验证模型的有效性或者特征的有效性，可以只使用训练集\n",
    "def get_all_click_df(data_path='./data_raw/', offline=True):\n",
    "    if offline:\n",
    "        all_click = pd.read_csv(data_path + 'train_click_log.csv')\n",
    "    else:\n",
    "        trn_click = pd.read_csv(data_path + 'train_click_log.csv')\n",
    "        tst_click = pd.read_csv(data_path + 'testA_click_log.csv')\n",
    "\n",
    "        all_click = trn_click.append(tst_click)\n",
    "    \n",
    "    all_click = all_click.drop_duplicates((['user_id', 'click_article_id', 'click_timestamp']))\n",
    "    return all_click"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 全量训练集\n",
    "all_click_df = get_all_click_df(offline=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 获取 用户 - 文章 - 点击时间字典\n",
    "\n",
    "其核心是把user_id, click_article, click_timestamp粘在一起成为item_time_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 根据点击时间获取用户的点击文章序列   {user1: [(item1, time1), (item2, time2)..]...}\n",
    "def get_user_item_time(click_df):\n",
    "    \n",
    "    click_df = click_df.sort_values('click_timestamp')\n",
    "    \n",
    "    def make_item_time_pair(df):\n",
    "        return list(zip(df['click_article_id'], df['click_timestamp']))\n",
    "    \n",
    "    user_item_time_df = click_df.groupby('user_id')['click_article_id', 'click_timestamp'].apply(lambda x: make_item_time_pair(x))\\\n",
    "                                                            .reset_index().rename(columns={0: 'item_time_list'})\n",
    "    user_item_time_dict = dict(zip(user_item_time_df['user_id'], user_item_time_df['item_time_list']))\n",
    "    \n",
    "    return user_item_time_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 获取点击最多的Topk个文章"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取近期点击最多的文章\n",
    "def get_item_topk_click(click_df, k):\n",
    "    topk_click = click_df['click_article_id'].value_counts().index[:k]\n",
    "    return topk_click"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### itemCF的物品相似度计算\n",
    "\n",
    "计算相似度：\n",
    "```\n",
    "\n",
    "i2i_sim[i][j] += 1 / math.log(len(item_time_list) + 1)\n",
    "\n",
    "i2i_sim_[i][j] = wij / math.sqrt(item_cnt[i] * item_cnt[j])\n",
    "```\n",
    "\n",
    "也就是$w_{ij}$对第i个item和第j个item的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def itemcf_sim(df):\n",
    "    \"\"\"\n",
    "        文章与文章之间的相似性矩阵计算\n",
    "        :param df: 数据表\n",
    "        :item_created_time_dict:  文章创建时间的字典\n",
    "        return : 文章与文章的相似性矩阵\n",
    "        思路: 基于物品的协同过滤(详细请参考上一期推荐系统基础的组队学习)， 在多路召回部分会加上关联规则的召回策略\n",
    "    \"\"\"\n",
    "    \n",
    "    user_item_time_dict = get_user_item_time(df)\n",
    "    \n",
    "    # 计算物品相似度\n",
    "    i2i_sim = {}\n",
    "    item_cnt = defaultdict(int)\n",
    "    for user, item_time_list in tqdm(user_item_time_dict.items()):\n",
    "        # 在基于商品的协同过滤优化的时候可以考虑时间因素\n",
    "        for i, i_click_time in item_time_list:\n",
    "            item_cnt[i] += 1\n",
    "            i2i_sim.setdefault(i, {})\n",
    "            for j, j_click_time in item_time_list:\n",
    "                if(i == j):\n",
    "                    continue\n",
    "                i2i_sim[i].setdefault(j, 0)\n",
    "                \n",
    "                i2i_sim[i][j] += 1 / math.log(len(item_time_list) + 1)\n",
    "                \n",
    "    i2i_sim_ = i2i_sim.copy()\n",
    "    for i, related_items in i2i_sim.items():\n",
    "        for j, wij in related_items.items():\n",
    "            i2i_sim_[i][j] = wij / math.sqrt(item_cnt[i] * item_cnt[j])\n",
    "    \n",
    "    # 将得到的相似性矩阵保存到本地\n",
    "    pickle.dump(i2i_sim_, open(save_path + 'itemcf_i2i_sim.pkl', 'wb'))\n",
    "    \n",
    "    return i2i_sim_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250000/250000 [00:23<00:00, 10564.97it/s]\n"
     ]
    }
   ],
   "source": [
    "i2i_sim = itemcf_sim(all_click_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### itemCF 的文章推荐"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 基于商品的召回i2i\n",
    "def item_based_recommend(user_id, user_item_time_dict, i2i_sim, sim_item_topk, recall_item_num, item_topk_click):\n",
    "    \"\"\"\n",
    "        基于文章协同过滤的召回\n",
    "        :param user_id: 用户id\n",
    "        :param user_item_time_dict: 字典, 根据点击时间获取用户的点击文章序列{user1: [(item1, time1), (item2, time2)..]...}\n",
    "        :param i2i_sim: 字典，文章相似性矩阵\n",
    "        :param sim_item_topk: 整数， 选择与当前文章最相似的前k篇文章\n",
    "        :param recall_item_num: 整数， 最后的召回文章数量\n",
    "        :param item_topk_click: 列表，点击次数最多的文章列表，用户召回补全        \n",
    "        return: 召回的文章列表 [item1:score1, item2: score2...]\n",
    "        注意: 基于物品的协同过滤(详细请参考上一期推荐系统基础的组队学习)， 在多路召回部分会加上关联规则的召回策略\n",
    "    \"\"\"\n",
    "    \n",
    "    # 获取用户历史交互的文章\n",
    "    user_hist_items = user_item_time_dict[user_id] # 注意，此时获取得到的是一个元组列表，需要将里面的user_id提取出来\n",
    "    user_hist_items_ = {user_id for user_id, _ in user_hist_items}\n",
    "\n",
    "    item_rank = {}\n",
    "    for loc, (i, click_time) in enumerate(user_hist_items):\n",
    "        for j, wij in sorted(i2i_sim[i].items(), key=lambda x: x[1], reverse=True)[:sim_item_topk]:\n",
    "            if j  in user_hist_items_:\n",
    "                continue\n",
    "                \n",
    "            item_rank.setdefault(j, 0)\n",
    "            item_rank[j] +=  wij\n",
    "    \n",
    "    # 不足10个，用热门商品补全\n",
    "    if len(item_rank) < recall_item_num:\n",
    "        for i, item in enumerate(item_topk_click):\n",
    "            if item in item_rank.items(): # 填充的item应该不在原来的列表中\n",
    "                continue\n",
    "            item_rank[item] = - i - 100 # 随便给个负数就行\n",
    "            if len(item_rank) == recall_item_num:\n",
    "                break\n",
    "    \n",
    "    item_rank = sorted(item_rank.items(), key=lambda x: x[1], reverse=True)[:recall_item_num]\n",
    "        \n",
    "    return item_rank"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 给每个用户根据物品的协同过滤推荐文章"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250000/250000 [46:05<00:00, 90.41it/s] \n"
     ]
    }
   ],
   "source": [
    "# 定义\n",
    "user_recall_items_dict = collections.defaultdict(dict)\n",
    "\n",
    "# 获取 用户 - 文章 - 点击时间的字典\n",
    "user_item_time_dict = get_user_item_time(all_click_df)\n",
    "\n",
    "# 去取文章相似度\n",
    "i2i_sim = pickle.load(open(save_path + 'itemcf_i2i_sim.pkl', 'rb'))\n",
    "\n",
    "# 相似文章的数量\n",
    "sim_item_topk = 10\n",
    "\n",
    "# 召回文章数量\n",
    "recall_item_num = 10\n",
    "\n",
    "# 用户热度补全\n",
    "item_topk_click = get_item_topk_click(all_click_df, k=50)\n",
    "\n",
    "for user in tqdm(all_click_df['user_id'].unique()):\n",
    "    user_recall_items_dict[user] = item_based_recommend(user, user_item_time_dict, i2i_sim, \n",
    "                                                        sim_item_topk, recall_item_num, item_topk_click)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 召回字典转换成df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250000/250000 [00:04<00:00, 52195.83it/s]\n"
     ]
    }
   ],
   "source": [
    "# 将字典的形式转换成df\n",
    "user_item_score_list = []\n",
    "\n",
    "for user, items in tqdm(user_recall_items_dict.items()):\n",
    "    for item, score in items:\n",
    "        user_item_score_list.append([user, item, score])\n",
    "\n",
    "recall_df = pd.DataFrame(user_item_score_list, columns=['user_id', 'click_article_id', 'pred_score'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 生成提交文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成提交文件\n",
    "def submit(recall_df, topk=5, model_name=None):\n",
    "    recall_df = recall_df.sort_values(by=['user_id', 'pred_score'])\n",
    "    recall_df['rank'] = recall_df.groupby(['user_id'])['pred_score'].rank(ascending=False, method='first')\n",
    "    \n",
    "    # 判断是不是每个用户都有5篇文章及以上\n",
    "    tmp = recall_df.groupby('user_id').apply(lambda x: x['rank'].max())\n",
    "    assert tmp.min() >= topk\n",
    "    \n",
    "    del recall_df['pred_score']\n",
    "    submit = recall_df[recall_df['rank'] <= topk].set_index(['user_id', 'rank']).unstack(-1).reset_index()\n",
    "    \n",
    "    submit.columns = [int(col) if isinstance(col, int) else col for col in submit.columns.droplevel(0)]\n",
    "    # 按照提交格式定义列名\n",
    "    submit = submit.rename(columns={'': 'user_id', 1: 'article_1', 2: 'article_2', \n",
    "                                                  3: 'article_3', 4: 'article_4', 5: 'article_5'})\n",
    "    \n",
    "    save_name = save_path + model_name + '_' + datetime.today().strftime('%m-%d') + '.csv'\n",
    "    submit.to_csv(save_name, index=False, header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取测试集\n",
    "tst_click = pd.read_csv(data_path + 'testA_click_log.csv')\n",
    "tst_users = tst_click['user_id'].unique()\n",
    "\n",
    "# 从所有的召回数据中将测试集中的用户选出来\n",
    "tst_recall = recall_df[recall_df['user_id'].isin(tst_users)]\n",
    "\n",
    "# 生成提交文件\n",
    "submit(tst_recall, topk=5, model_name='itemcf_baseline')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结\n",
    "\n",
    "理解赛题：从直观上对问题进行梳理， 分析问题的目标，到底要让做什么事情, 这个非常重要\n",
    "\n",
    "理解数据：对赛题数据有一个初步了解，知道和任务相关的数据字段和数据字段的类型， 数据之间的内在关联等，大体梳理一下哪些数据会对我们解决问题非常有用，方便后面我们的数据分析和特征工程。\n",
    "\n",
    "理解评估指标：评估指标是检验我们提出的方法，我们给出结果好坏的标准，只有正确的理解了评估指标，我们才能进行更好的训练模型，更好的进行预测。此外，很多情况下，线上验证是有一定的时间和次数限制的，所以在比赛中构建一个合理的本地的验证集和验证的评价指标是很关键的步骤，能有效的节省很多时间。 不同的指标对于同样的预测结果是具有误差敏感的差异性的所以不同的评价指标会影响后续一些预测的侧重点。\n",
    "\n",
    "在对于赛题有了一定的了解后，分析清楚了问题的类型性质和对于数据理解 的这一基础上，我们可以梳理一个解决赛题的一个大题思路和框架\n",
    "\n",
    "我们至少要有一些相应的理解分析，比如这题的难点可能在哪里，关键点可能在哪里，哪些地方可以挖掘更好的特征.\n",
    "\n",
    "用什么样得线下验证方式更为稳定，出现了过拟合或者其他问题，估摸可以用什么方法去解决这些问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
